

# import pandas as pd
# import seaborn as sns
# import matplotlib.pyplot as plt
# from scipy.stats import ttest_rel
# plt.rcParams['font.family'] = 'Times New Roman'

# # Define dictionary of files
# files = {    
#     # "CIFAR10_AllCNN": r"C:/Temp/Unlearning/Results/Cifar 10 AllCNN/compiled_results_MIAU.csv",
#     # "CIFAR10_ResNet": r"C:/Temp/Unlearning/Results/Cifar 10 Resnet/compiled_results_MIAU.csv",
#     # "CIFAR20_AllCNN": r"C:/Temp/Unlearning/Results/Cifar 20 AllCNN/compiled_results_MIAU.csv",
#     # "CIFAR20_ResNet": r"C:/Temp/Unlearning/Results/Cifar 20 Resnet/compiled_results_MIAU.csv",
#     # "CIFAR10_ViT": r"C:/Temp/Unlearning/Results/Cifar 10 ViT/compiled_results_MIAU.csv",
#     # "MNIST_ResNet": r"C:/Temp/Unlearning/Results/MNIST Resnet/compiled_results_MIAU.csv",
#     # "MNIST_AllCNN": r"C:/Temp/Unlearning/Results/MNIST AllCNN/compiled_results_MIAU.csv",
#     # "MUCAC_ResNet": r"C:/Temp/Unlearning/Results/MUCAC Resnet/compiled_results_MIAU.csv"
    
#      "CIFAR10_ResNet_Underfitted": r"C:/Temp/Unlearning/Results/Underfitted/compiled_results_MIAU.csv",
#     "CIFAR10_ResNet_Overfitted": r"C:/Temp/Unlearning/Results/Overfitted/compiled_results_MIAU.csv"
# }


# comparisons = [
#     ("retrain50", "retrain25"),
#     ("retrain75", "retrain25"),
#     ("retrain75", "retrain50")
# ]

# pval_dict = {}

# for name, path in files.items():
#     try:
#         df = pd.read_csv(path)

#         methods = ['retrain25', 'retrain50', 'retrain75']
#         df_filtered = df[df['unlearning'].isin(methods)]
#         df_pivot = df_filtered.pivot(index='seed', columns='unlearning', values='MIAU')

#         if not all(method in df_pivot.columns for method in methods):
#             continue

#         result = {}
#         for high, low in comparisons:
#             pval = ttest_rel(df_pivot[high], df_pivot[low], alternative='greater').pvalue
#             result[f"{high} > {low}"] = pval

#         pval_dict[name] = result

#     except Exception as e:
#         print(f"Error processing {name}: {e}")

# df_pvals = pd.DataFrame.from_dict(pval_dict, orient='index')

# plt.figure(figsize=(10, len(df_pvals) * 0.6))
# sns.heatmap(df_pvals, annot=True, fmt=".4f", cmap="Reds", cbar_kws={'label': 'p-value'})
# plt.title("One-sided p-values for MIAU comparisons", fontname='Times New Roman')
# plt.xlabel("Comparison", fontname='Times New Roman')
# plt.ylabel("Dataset", fontname='Times New Roman')
# plt.tight_layout()
# plt.savefig(r"C:/Temp/Unlearning/figure_p_value_heatmap_nongeneralized.pdf", dpi=300, bbox_inches='tight')

# plt.show()



import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel, t
import numpy as np

# Make Times New Roman the default font family for all text in the figures
# produced by this script.
plt.rcParams['font.family'] = 'Times New Roman'

# Maps a dataset/architecture label to the CSV holding its compiled MIAU
# results. The label is what identifies the dataset in the printed statistics
# and in the p-value heatmap. All CSVs live under the "Data Appendix" directory.
files = {
    "CIFAR10_AllCNN": r"C:/Temp/Unlearning/Data Appendix/Cifar 10 AllCNN/compiled_results_MIAU.csv",
    "CIFAR10_ResNet": r"C:/Temp/Unlearning/Data Appendix/Cifar 10 Resnet/compiled_results_MIAU.csv",
    "CIFAR20_AllCNN": r"C:/Temp/Unlearning/Data Appendix/Cifar 20 AllCNN/compiled_results_MIAU.csv",
    "CIFAR20_ResNet": r"C:/Temp/Unlearning/Data Appendix/Cifar 20 Resnet/compiled_results_MIAU.csv",
    "CIFAR10_ViT":    r"C:/Temp/Unlearning/Data Appendix/Cifar 10 ViT/compiled_results_MIAU.csv",
    "MNIST_ResNet":   r"C:/Temp/Unlearning/Data Appendix/MNIST Resnet/compiled_results_MIAU.csv",
    "MNIST_AllCNN":   r"C:/Temp/Unlearning/Data Appendix/MNIST AllCNN/compiled_results_MIAU.csv",
    "MUCAC_ResNet":   r"C:/Temp/Unlearning/Data Appendix/MUCAC Resnet/compiled_results_MIAU.csv",
    "CIFAR10_ResNet_Saliency": r"C:\Temp\Unlearning\Data Appendix\Cifar 10 Resnet Saliency\compiled_results_MIAU.csv",
}

# Each entry is a (high, low) pair of retrain settings. For every pair the
# script tests, one-sided, whether MIAU under `high` is greater than under `low`.
comparisons = [
    ("retrain50", "retrain25"),
    ("retrain75", "retrain25"),
    ("retrain75", "retrain50"),
]

# Accumulators filled while iterating over the datasets:
#   pval_dict  -> {dataset label: {comparison label: p-value}} (heatmap input)
#   stats_rows -> one dict of summary statistics per dataset/comparison
pval_dict, stats_rows = {}, []

def cohens_d_paired(x, y):
    """Return Cohen's d for two paired samples.

    The effect size is the mean of the element-wise differences ``x - y``
    divided by their sample standard deviation (``ddof=1``).

    Args:
        x: First sample; must support ``x - y`` and the result must expose
            ``.mean()`` and ``.std(ddof=...)`` (e.g. a pandas Series or a
            NumPy array).
        y: Second sample, paired element-wise with ``x``.

    Returns:
        The standardized mean difference of the pairs.
    """
    paired_diff = x - y
    mean_of_diff = paired_diff.mean()
    sd_of_diff = paired_diff.std(ddof=1)
    return mean_of_diff / sd_of_diff

# For every dataset, compare MIAU between retrain settings with a one-sided
# paired t-test (pairs are matched by seed) and record the p-value, the mean
# paired difference with its 95% CI, and the paired Cohen's d.

# Retrain settings that must all be present for a dataset to be analysed.
methods = ['retrain25', 'retrain50', 'retrain75']

for name, path in files.items():
    try:
        df = pd.read_csv(path)
        df_filtered = df[df['unlearning'].isin(methods)]
        # One row per seed, one column per retrain setting. pivot() raises if a
        # (seed, unlearning) combination occurs more than once; that error is
        # reported by the handler below.
        df_pivot = df_filtered.pivot(index='seed', columns='unlearning', values='MIAU')

        missing = [method for method in methods if method not in df_pivot.columns]
        if missing:
            print(f"Skipping {name}: no rows for {missing}")
            continue

        pvals = {}
        for high, low in comparisons:
            x = df_pivot[high].dropna()
            y = df_pivot[low].dropna()
            # A paired test needs both settings measured on exactly the same
            # seeds. Comparing lengths alone is not enough: if different seeds
            # are missing in each column the lengths can match while the
            # values would be paired across different seeds.
            if len(x) < 2 or not x.index.equals(y.index):
                print(f"Skipping {name} ({high} > {low}): "
                      f"need at least 2 seeds with both values present")
                continue

            # H1: MIAU under `high` is greater than under `low`.
            t_res = ttest_rel(x, y, alternative='greater')
            pval = t_res.pvalue

            # Effect size (paired Cohen's d)
            d_val = cohens_d_paired(x, y)

            # 95% CI for mean paired difference
            diff = x - y
            n = len(diff)
            mean_diff = diff.mean()
            se = diff.std(ddof=1) / np.sqrt(n)
            t_crit = t.ppf(0.975, df=n - 1)  # two-sided 95% CI
            ci_low = mean_diff - t_crit * se
            ci_high = mean_diff + t_crit * se

            pvals[f"{high}>{low}"] = pval
            stats_rows.append({
                "Dataset": name,
                "Comparison": f"{high} > {low}",
                "n": n,
                "Mean diff": mean_diff,
                "95% CI lower": ci_low,
                "95% CI upper": ci_high,
                "Cohen's d": d_val,
                "p-value": pval
            })

        pval_dict[name] = pvals

    except Exception as e:
        # Best-effort batch: report the failing dataset and keep going so one
        # unreadable or malformed CSV does not abort the whole analysis.
        print(f"Error processing {name}: {e}")

# Detailed stats printed only
df_stats = pd.DataFrame(stats_rows)
print("\nDetailed statistics (mean diff, 95% CI, Cohen's d, p-value):\n")
print(df_stats.to_string(index=False))

# Heatmap of p-values: one row per dataset, one column per comparison.
df_pvals = pd.DataFrame.from_dict(pval_dict, orient='index')
if df_pvals.empty:
    # Nothing to draw (e.g. every CSV failed to load or every comparison was
    # skipped); seaborn cannot render a heatmap from an empty frame.
    print("No p-values were computed; skipping the heatmap.")
else:
    # Height grows with the number of datasets, with a floor of 4 inches.
    plt.figure(figsize=(10, max(4, len(df_pvals) * 0.6)))
    sns.heatmap(df_pvals, annot=True, fmt=".4f", cmap="Reds",
                cbar_kws={'label': 'p-value'})
    plt.title("One-sided p-values for MIAU comparisons", fontname='Times New Roman')
    plt.xlabel("Comparison", fontname='Times New Roman')
    plt.ylabel("Dataset", fontname='Times New Roman')
    plt.tight_layout()
    plt.savefig(r"C:/Temp/Unlearning/figure_p_value_heatmap.pdf",
                dpi=300, bbox_inches='tight')
    plt.show()
